// Copyright (c) 2023, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package generate provides utilities for building code generators in Go.
// The standard path for a code generator is: [Load] -> [PrintHeader] -> [Inspect] -> [Write].
package generate

import (
	"fmt"
	"go/ast"
	"io"
	"os"
	"path/filepath"
	"strings"

	"golang.org/x/tools/go/packages"
	"golang.org/x/tools/imports"
)

// Load loads the Go packages matched by patterns using [packages.Load]
// with the given configuration.
//
// On success the returned slice is never empty: if packages.Load reports
// no error but yields zero packages, Load returns an error instead. Errors
// that packages.Load records on individual packages (their Errors field)
// are not inspected here and remain the caller's responsibility.
func Load(cfg *packages.Config, patterns ...string) ([]*packages.Package, error) {
	loaded, err := packages.Load(cfg, patterns...)
	switch {
	case err != nil:
		return nil, err
	case len(loaded) == 0:
		return nil, fmt.Errorf("expected at least one package but got %d", len(loaded))
	}
	return loaded, nil
}

// PrintHeader prints a header to the given writer for a generated
// file in the given package with the given imports.
//
// The header begins with a "// Code generated by ...; DO NOT EDIT." line
// naming the running command: the base name of os.Args[0] with any ".exe"
// suffix removed, followed by the remaining command-line arguments. It is
// followed by a blank line, the package clause and, if any imports are
// given, a blank line and a parenthesized import block.
//
// Imports do not need to be set if you are running [Format] on the code
// later, but they should be set for any external packages that may not
// be found correctly by goimports.
func PrintHeader(w io.Writer, pkg string, imports ...string) {
	cmdstr := strings.TrimSuffix(filepath.Base(os.Args[0]), ".exe")
	if len(os.Args) > 1 {
		cmdstr += " " + strings.Join(os.Args[1:], " ")
	}
	fmt.Fprintf(w, "// Code generated by \"%s\"; DO NOT EDIT.\n\n", cmdstr)
	fmt.Fprintf(w, "package %s\n", pkg)
	if len(imports) > 0 {
		// Separate the import block from the package clause the way gofmt
		// does, so the header is tidy even if the output is never formatted.
		fmt.Fprint(w, "\nimport (\n")
		for _, imp := range imports {
			fmt.Fprintf(w, "\t%q\n", imp)
		}
		fmt.Fprint(w, ")\n")
	}
}

// ExcludeFile reports whether file should be skipped. It looks up the name
// of the source file that file was parsed from in pkg.Fset, takes the last
// path element, and returns true if that name exactly equals (case-sensitively)
// any entry in exclude.
func ExcludeFile(pkg *packages.Package, file *ast.File, exclude ...string) bool {
	_, base := filepath.Split(pkg.Fset.Position(file.FileStart).Filename)
	for i := range exclude {
		if exclude[i] == base {
			return true
		}
	}
	return false
}

// Inspect goes through all of the files in pkg.Syntax, except those whose
// base name is in the exclude list (see [ExcludeFile]), and calls f on each
// node using [ast.Inspect]. As with ast.Inspect, f is also called with a
// nil node after the children of a node have been visited.
//
// The bool return value from f indicates whether to continue traversing
// down the AST tree of that node and look at its children. If f returns
// a non-nil error, f is not called again, the traversal is stopped, and
// the error is returned wrapped with the type and source position of the
// node that produced it.
func Inspect(pkg *packages.Package, f func(n ast.Node) (bool, error), exclude ...string) error {
	for _, file := range pkg.Syntax {
		if ExcludeFile(pkg, file, exclude...) {
			continue
		}
		var terr error
		var terrNode ast.Node
		ast.Inspect(file, func(n ast.Node) bool {
			if terr != nil {
				// ast.Inspect cannot be aborted; prune everything after a failure.
				return false
			}
			cont, err := f(n)
			if err != nil {
				terr = err
				terrNode = n
				// Do not descend into the children of the failing node.
				return false
			}
			return cont
		})
		if terr != nil {
			// %v on most AST nodes dumps raw struct fields and pointers, so
			// report the node type and its file:line:column instead.
			where := "<nil>"
			if terrNode != nil {
				where = fmt.Sprintf("%T at %s", terrNode, pkg.Fset.Position(terrNode.Pos()))
			}
			return fmt.Errorf("generate.Inspect: error while calling inspect function for node %s: %w", where, terr)
		}
	}
	return nil
}

// Filepath returns the path of a file named filename located in the
// directory of the given package.
//
// The package directory is taken from the first parsed file in pkg.Syntax.
// If no syntax was loaded (the load mode lacked packages.NeedSyntax), it is
// taken from the first entry of pkg.GoFiles instead. If neither is
// available, filename is joined to ".", i.e. resolved against the current
// working directory.
func Filepath(pkg *packages.Package, filename string) string {
	dir := "."
	switch {
	case len(pkg.Syntax) > 0:
		dir = filepath.Dir(pkg.Fset.Position(pkg.Syntax[0].FileStart).Filename)
	case len(pkg.GoFiles) > 0:
		// Without syntax trees there is no position info, but GoFiles still
		// records where the package's sources live.
		dir = filepath.Dir(pkg.GoFiles[0])
	}
	return filepath.Join(dir, filename)
}

// Write formats src with goimports via [Format] using opt and writes the
// result to filename with permission bits 0666 (before umask).
//
// The file is written even when formatting fails; in that case the bytes
// returned by Format are written. A write error takes precedence; otherwise
// any formatting error is returned after the file has been written.
func Write(filename string, src []byte, opt *imports.Options) error {
	out, formatErr := Format(filename, src, opt)
	// Writing the output even when formatting failed lets the user inspect it
	// (e.g. by compiling the package); the formatting error is reported below.
	if err := os.WriteFile(filename, out, 0666); err != nil {
		return fmt.Errorf("generate.Write: error writing file: %w", err)
	}
	if formatErr == nil {
		return nil
	}
	return fmt.Errorf("generate.Write: error formatting code: %w", formatErr)
}

// Format runs goimports ([imports.Process]) on src with the given filename
// and options and returns the processed bytes.
//
// If processing fails, Format returns the original src unchanged together
// with an error that wraps the goimports error and adds context about where
// it came from.
func Format(filename string, src []byte, opt *imports.Options) ([]byte, error) {
	formatted, err := imports.Process(filename, src, opt)
	if err == nil {
		return formatted, nil
	}
	// Failure means the generator emitted invalid Go, usually while the
	// generator itself is being developed; compiling the output shows why.
	return src, fmt.Errorf("internal/programmer error: generate.Format: invalid Go generated: %w; compile the package to analyze the error", err)
}
